/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.state.api;

import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.time.Deadline;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.client.program.ClusterClient;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.memory.MemoryStateBackend;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction;
import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.util.AbstractID;
import org.apache.flink.util.Collector;

import org.junit.Assert;
import org.junit.Test;

import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

/** IT case for reading state. */
public abstract class SavepointReaderITTestBase extends AbstractTestBase {
    static final String UID = "stateful-operator";

    static final String LIST_NAME = "list";

    static final String UNION_NAME = "union";

    static final String BROADCAST_NAME = "broadcast";

    private final ListStateDescriptor<Integer> list;

    private final ListStateDescriptor<Integer> union;

    private final MapStateDescriptor<Integer, String> broadcast;

    SavepointReaderITTestBase(
            ListStateDescriptor<Integer> list,
            ListStateDescriptor<Integer> union,
            MapStateDescriptor<Integer, String> broadcast) {

        this.list = list;
        this.union = union;
        this.broadcast = broadcast;
    }

    @Test
    public void testOperatorStateInputFormat() throws Exception {
        StreamExecutionEnvironment streamEnv = StreamExecutionEnvironment.getExecutionEnvironment();
        streamEnv.setParallelism(4);

        DataStream<Integer> data = streamEnv.addSource(new SavepointSource()).rebalance();

        data.connect(data.broadcast(broadcast))
                .process(new StatefulOperator(list, union, broadcast))
                .uid(UID)
                .addSink(new DiscardingSink<>());

        JobGraph jobGraph = streamEnv.getStreamGraph().getJobGraph();

        String savepoint = takeSavepoint(jobGraph);

        ExecutionEnvironment batchEnv = ExecutionEnvironment.getExecutionEnvironment();

        verifyListState(savepoint, batchEnv);

        verifyUnionState(savepoint, batchEnv);

        verifyBroadcastState(savepoint, batchEnv);
    }

    abstract DataSet<Integer> readListState(ExistingSavepoint savepoint) throws IOException;

    abstract DataSet<Integer> readUnionState(ExistingSavepoint savepoint) throws IOException;

    abstract DataSet<Tuple2<Integer, String>> readBroadcastState(ExistingSavepoint savepoint)
            throws IOException;

    private void verifyListState(String path, ExecutionEnvironment batchEnv) throws Exception {
        ExistingSavepoint savepoint = Savepoint.load(batchEnv, path, new MemoryStateBackend());
        List<Integer> listResult = readListState(savepoint).collect();
        listResult.sort(Comparator.naturalOrder());

        Assert.assertEquals(
                "Unexpected elements read from list state",
                SavepointSource.getElements(),
                listResult);
    }

    private void verifyUnionState(String path, ExecutionEnvironment batchEnv) throws Exception {
        ExistingSavepoint savepoint = Savepoint.load(batchEnv, path, new MemoryStateBackend());
        List<Integer> unionResult = readUnionState(savepoint).collect();
        unionResult.sort(Comparator.naturalOrder());

        Assert.assertEquals(
                "Unexpected elements read from union state",
                SavepointSource.getElements(),
                unionResult);
    }

    private void verifyBroadcastState(String path, ExecutionEnvironment batchEnv) throws Exception {
        ExistingSavepoint savepoint = Savepoint.load(batchEnv, path, new MemoryStateBackend());
        List<Tuple2<Integer, String>> broadcastResult = readBroadcastState(savepoint).collect();

        List<Integer> broadcastStateKeys =
                broadcastResult.stream()
                        .map(entry -> entry.f0)
                        .sorted(Comparator.naturalOrder())
                        .collect(Collectors.toList());

        List<String> broadcastStateValues =
                broadcastResult.stream()
                        .map(entry -> entry.f1)
                        .sorted(Comparator.naturalOrder())
                        .collect(Collectors.toList());

        Assert.assertEquals(
                "Unexpected element in broadcast state keys",
                SavepointSource.getElements(),
                broadcastStateKeys);

        Assert.assertEquals(
                "Unexpected element in broadcast state values",
                SavepointSource.getElements().stream()
                        .map(Object::toString)
                        .sorted()
                        .collect(Collectors.toList()),
                broadcastStateValues);
    }

    private String takeSavepoint(JobGraph jobGraph) throws Exception {
        SavepointSource.initializeForTest();

        ClusterClient<?> client = miniClusterResource.getClusterClient();
        JobID jobId = jobGraph.getJobID();

        Deadline deadline = Deadline.fromNow(Duration.ofMinutes(5));

        String dirPath = getTempDirPath(new AbstractID().toHexString());

        try {
            JobID jobID = client.submitJob(jobGraph).get();

            boolean finished = false;
            while (deadline.hasTimeLeft()) {
                if (SavepointSource.isFinished()) {
                    finished = true;

                    break;
                }

                try {
                    Thread.sleep(2L);
                } catch (InterruptedException ignored) {
                    Thread.currentThread().interrupt();
                }
            }

            if (!finished) {
                Assert.fail("Failed to initialize state within deadline");
            }

            CompletableFuture<String> path = client.triggerSavepoint(jobID, dirPath);
            return path.get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
        } finally {
            client.cancel(jobId).get();
        }
    }

    private static class SavepointSource implements SourceFunction<Integer> {
        private static volatile boolean finished;

        private volatile boolean running = true;

        private static final Integer[] elements = {1, 2, 3};

        @Override
        public void run(SourceContext<Integer> ctx) {
            synchronized (ctx.getCheckpointLock()) {
                for (Integer element : elements) {
                    ctx.collect(element);
                }

                finished = true;
            }

            while (running) {
                try {
                    Thread.sleep(100);
                } catch (InterruptedException e) {
                    // ignore
                }
            }
        }

        @Override
        public void cancel() {
            running = false;
        }

        private static void initializeForTest() {
            finished = false;
        }

        private static boolean isFinished() {
            return finished;
        }

        private static List<Integer> getElements() {
            return Arrays.asList(elements);
        }
    }

    private static class StatefulOperator extends BroadcastProcessFunction<Integer, Integer, Void>
            implements CheckpointedFunction {

        private final ListStateDescriptor<Integer> list;
        private final ListStateDescriptor<Integer> union;
        private final MapStateDescriptor<Integer, String> broadcast;

        private List<Integer> elements;

        private ListState<Integer> listState;

        private ListState<Integer> unionState;

        private StatefulOperator(
                ListStateDescriptor<Integer> list,
                ListStateDescriptor<Integer> union,
                MapStateDescriptor<Integer, String> broadcast) {

            this.list = list;
            this.union = union;
            this.broadcast = broadcast;
        }

        @Override
        public void open(Configuration parameters) {
            elements = new ArrayList<>();
        }

        @Override
        public void processElement(Integer value, ReadOnlyContext ctx, Collector<Void> out) {
            elements.add(value);
        }

        @Override
        public void processBroadcastElement(Integer value, Context ctx, Collector<Void> out)
                throws Exception {
            ctx.getBroadcastState(broadcast).put(value, value.toString());
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            listState.clear();

            listState.addAll(elements);

            unionState.clear();

            unionState.addAll(elements);
        }

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            listState = context.getOperatorStateStore().getListState(list);

            unionState = context.getOperatorStateStore().getUnionListState(union);
        }
    }
}
